import torch
import torch.nn as nn

# 将x 最后 一个维度线性变换成指定维度
layer = nn.Linear(in_features=128, out_features=1)
x = torch.rand((32, 128))
y = layer(x)

layer = nn.Flatten(start_dim=1, end_dim=3)
x = torch.rand((32, 3, 128, 128))
y = layer(x)  # torch.Size([32, 49152])

print("完结撒花！")
